{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import json, json5\n",
    "import numpy as np\n",
    "from tqdm import tqdm\n",
    "import glob\n",
    "import pandas as pd\n",
    "import cprint\n",
    "import BboxToolkit as bt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "9350"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Object categories accepted for an annotation's lower-cased 'obj_cls'.\n",
    "all_classes = ['dam', 'storage-tank', 'ground-track-field', 'overpass', 'baseball-diamond', 'tennis-court', 'vehicle', 'basketball-court', 'golffield', 'harbor', 'expressway-service-area', 'chimney', 'trainstation', 'windmill', 'expressway-toll-station', 'ship', 'airport', 'bridge', 'airplane', 'stadium', 'soccer-ball-field', 'roundabout', 'swimming-pool', 'helicopter', 'container-crane', 'helipad']\n",
    "\n",
    "# Substrings that disqualify a QA pair when found in its question or answer.\n",
    "# The leading/trailing spaces around 'tail' are part of the phrases.\n",
    "exclude_phrases = [\n",
    "    'flag', 'not provide', 'not specified', 'unknown', 'referred', 'referring',\n",
    "    'nose', 'vertical stabilizer', ' tail', 'tail ', 'facing', 'pointing',\n",
    "    'first-mentioned', 'aforementioned', 'previously mentioned', 'motion', 'day', 'night',\n",
    "]\n",
    "\n",
    "# Allowed values of a QA pair's 'type' after strip().lower().\n",
    "correct_qa_types = ['object existence', 'object position', 'object quantity', 'object category', 'object color', 'object shape', 'object size', 'object direction', 'scene type', 'image', 'reasoning', 'rural or urban']\n",
    "\n",
    "# Earlier annotation sources, kept for provenance only. They used to be globbed and then\n",
    "# immediately overwritten by the assignment below, which was a wasted directory scan.\n",
    "# all_files = glob.glob('GPT4V_Inference/dior_split_train_0000/*.json') + glob.glob('GPT4V_Inference/dior_split_val_0000/*.json')\n",
    "# all_files = glob.glob('./Human_Check/Label_Your_Data_Kaust_DOTA_VAL_cleaned/*.json') + glob.glob('Human_Check/Label_Your_Data_Kaust_DIOR_VAL_cleaned/*.json')\n",
    "\n",
    "# Annotation files actually processed by the cells below.\n",
    "all_files = glob.glob('Final_Data/v1.2/Annotations_val/*.json')\n",
    "\n",
    "# Drop files whose name ends with 'input.json'; only the remaining JSON files are parsed.\n",
    "all_files = [f for f in all_files if not f.endswith('input.json')]\n",
    "\n",
    "len(all_files)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Prepare Referring"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9350/9350 [04:00<00:00, 38.90it/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid_count 16159 9350 0\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Build one referring-expression evaluation item per annotated object.\n",
    "# With use_obb set, each item also carries two oriented-box encodings:\n",
    "#   obb_v1 - first four OBB values (unpacked as x, y, w, h) plus the angle\n",
    "#   obb_v2 - the corners nearest to / farthest from the origin plus the angle\n",
    "all_items = []\n",
    "valid_count = 0\n",
    "non_valid = []\n",
    "\n",
    "use_obb = True\n",
    "angle_all = []  # rounded OBB angle (degrees) of every processed object\n",
    "\n",
    "for json_path in tqdm(sorted(all_files)):\n",
    "    if json_path.endswith('input.json'):\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        with open(json_path, 'r') as f:\n",
    "            data = json5.load(f)\n",
    "\n",
    "        # Default to an empty list so a file without 'objects' cannot silently\n",
    "        # reuse the objects parsed from the previous file.\n",
    "        all_refs = data.get('objects', [])\n",
    "\n",
    "        for ref in all_refs:\n",
    "            if 'referring_sentence' not in ref or 'obj_corner' not in ref or 'obj_cls' not in ref:\n",
    "                print('missing ref info', data['image'])\n",
    "                # Skip only this object; falling through would raise KeyError below\n",
    "                # and drop the remaining objects of the file.\n",
    "                continue\n",
    "\n",
    "            bbox = ref['obj_corner']\n",
    "            # convert coordinates to float\n",
    "            bbox = np.array(bbox).astype(np.float32)\n",
    "            hbbs = bt.poly2hbb(bbox)\n",
    "            # Scale by 100 and truncate to int (presumably the corners are normalized - TODO confirm).\n",
    "            hbbs = [int(v*100) for v in hbbs]\n",
    "\n",
    "            obj_cls = ref['obj_cls'].lower()\n",
    "            # An AssertionError is caught by the except below, so the file is reported and skipped.\n",
    "            assert obj_cls in all_classes, f\"invalid class {obj_cls}\"\n",
    "\n",
    "            bbox_str = \"{\" + f\"<{hbbs[0]}><{hbbs[1]}><{hbbs[2]}><{hbbs[3]}>\" + \"}\"\n",
    "\n",
    "            if use_obb:\n",
    "                # calculate the angle based on the above four corner points\n",
    "                obbs = bt.bbox2type(bbox, 'obb')\n",
    "                angle = obbs[4]\n",
    "                angle = np.round(np.rad2deg(angle)).astype(np.int32)\n",
    "                angle_all.append(angle)\n",
    "                # use_obb_v1\n",
    "                x, y, w, h = [int(v*100) for v in obbs[:4]]\n",
    "                bbox_str1 = \"{\" + f\"<{x}><{y}><{w}><{h}>|<{angle}>\" + \"}\"\n",
    "\n",
    "                # use_obb_v2: pick the corners with the smallest / largest squared\n",
    "                # distance from the origin\n",
    "                corners = bbox.reshape(4, 2)\n",
    "                dist = corners[:,0]**2 + corners[:,1]**2\n",
    "                tl = (corners[np.argmin(dist)] * 100).astype('int32')\n",
    "                br = (corners[np.argmax(dist)] * 100).astype('int32')\n",
    "                bbox_str2 = \"{\" + f\"<{tl[0]}><{tl[1]}><{br[0]}><{br[1]}>|<{angle}>\" + \"}\"\n",
    "\n",
    "            obj_size = ref['obj_size']\n",
    "\n",
    "            item_dict = {\n",
    "                \"image_id\": data['image'], \n",
    "                \"question\": ref['referring_sentence'], \n",
    "                \"ground_truth\": bbox_str,\n",
    "                \"dataset\": \"RSBench\", \n",
    "                \"question_id\": valid_count, \n",
    "                \"type\": \"ref\", \n",
    "                \"unique\": ref['is_unique'],\n",
    "                \"obj_corner\": ref['obj_corner'],\n",
    "                \"obj_cls\": obj_cls,\n",
    "                \"obj_ids\": [0, ], \n",
    "                \"size_group\": obj_size,\n",
    "            }\n",
    "            if use_obb:\n",
    "                item_dict.update({\n",
    "                    \"obb_v1\": bbox_str1,\n",
    "                    \"obb_v2\": bbox_str2,\n",
    "                })\n",
    "\n",
    "            all_items.append(item_dict)\n",
    "\n",
    "            valid_count += 1\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: report the file, remember it, and keep going.\n",
    "        print('skipping', json_path, 'error', e)\n",
    "        non_valid.append(json_path)\n",
    "        continue\n",
    "\n",
    "# Saving is left disabled, as before; uncomment to write the referring items to disk.\n",
    "# with open('RSBench_EVAL_referring_obb.json', 'w') as f:\n",
    "#     json.dump(all_items, f, indent=4)\n",
    "\n",
    "print('valid_count', valid_count, len(all_files), len(non_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(-90, -12.368277739959156, 90)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Min / mean / max of the rounded OBB angles (degrees) collected in angle_all.\n",
    "tuple(stat(angle_all) for stat in (np.min, np.mean, np.max))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inflect\n",
    "\n",
    "# Spelled-out words for the integers 0..99, indexed by the number itself.\n",
    "convert = inflect.engine()\n",
    "all_numbers = list(map(convert.number_to_words, range(100)))\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# VQA"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9350/9350 [04:01<00:00, 38.69it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid_count 37409 9350 0 0\n"
     ]
    }
   ],
   "source": [
    "# Build VQA evaluation items from the 'qa_pairs' of every annotation file and\n",
    "# write them to RSBench_EVAL_vqa_v2.json.\n",
    "all_items = []\n",
    "valid_count = 0\n",
    "non_valid = []\n",
    "\n",
    "use_obb = False\n",
    "skip_qas = 0\n",
    "\n",
    "# Loop-invariant: phrases whose presence in a question or answer disqualifies the pair.\n",
    "vqa_exclude_phrases = exclude_phrases + ['source', 'resolution']\n",
    "\n",
    "for json_path in tqdm(sorted(all_files)):\n",
    "    if json_path.endswith('input.json'):\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        with open(json_path, 'r') as f:\n",
    "            data = json5.load(f)\n",
    "\n",
    "        # Default to an empty list so a file without 'qa_pairs' cannot silently\n",
    "        # reuse the pairs parsed from the previous file.\n",
    "        all_qas = data.get('qa_pairs', [])\n",
    "\n",
    "        for qa in all_qas:\n",
    "            if 'question' not in qa or 'answer' not in qa or 'type' not in qa:\n",
    "                print('missing qa info', data['image'])\n",
    "                continue\n",
    "\n",
    "            # First excluded phrase found in the answer or the question, if any.\n",
    "            # NOTE(review): this substring match is case-sensitive, so e.g. 'Unknown'\n",
    "            # is not caught - confirm this is intended.\n",
    "            hit = next(\n",
    "                (phr for phr in vqa_exclude_phrases\n",
    "                 if phr in qa['answer'] or phr in qa['question']),\n",
    "                None,\n",
    "            )\n",
    "            if hit is not None:\n",
    "                print('skipping qa', json_path, 'due to', hit)\n",
    "                skip_qas += 1\n",
    "                continue\n",
    "\n",
    "            # Drop 'which ...' questions whose answer contains a number 0..99, either\n",
    "            # as digits or spelled out (plain substring match).\n",
    "            if 'which' in qa['question'].lower():\n",
    "                has_number = any(\n",
    "                    str(num) in qa['answer'] or all_numbers[num] in qa['answer']\n",
    "                    for num in range(100)\n",
    "                )\n",
    "                if has_number:\n",
    "                    skip_qas += 1\n",
    "                    print('question by object order, '+ qa['answer'])\n",
    "                    continue\n",
    "\n",
    "            qa['type'] = qa['type'].strip().lower()\n",
    "            # An AssertionError is caught by the except below, so the file is reported and skipped.\n",
    "            assert qa['type'] in correct_qa_types, qa['type']\n",
    "\n",
    "            item_dict = {\n",
    "                \"image_id\": data['image'], \n",
    "                \"question\": qa['question'], \n",
    "                \"ground_truth\": qa['answer'], \n",
    "                \"dataset\": \"RSBench\", \n",
    "                \"question_id\": valid_count, \n",
    "                \"type\": qa['type'], \n",
    "            }\n",
    "            all_items.append(item_dict)\n",
    "\n",
    "            valid_count += 1\n",
    "\n",
    "    except Exception as e:\n",
    "        # Best effort: report the file, remember it, and keep going.\n",
    "        print('skipping', json_path, 'error', e)\n",
    "        non_valid.append(json_path)\n",
    "        continue\n",
    "\n",
    "with open('RSBench_EVAL_vqa_v2.json', 'w') as f:\n",
    "    json.dump(all_items, f, indent=4)\n",
    "\n",
    "print('valid_count', valid_count, len(all_files), len(non_valid), skip_qas)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "(37409, 0)"
      ]
     },
     "execution_count": 19,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "valid_count, skip_qas"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 9350/9350 [04:02<00:00, 38.49it/s]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "valid_count 9350 9350 0\n"
     ]
    }
   ],
   "source": [
    "# Collect one captioning item per annotation file that has a 'caption' field,\n",
    "# then dump the list to RSBench_EVAL_Cap_v2.json.\n",
    "all_items = []\n",
    "valid_count = 0\n",
    "non_valid = []\n",
    "\n",
    "use_obb = False\n",
    "\n",
    "for json_path in tqdm(sorted(all_files)):\n",
    "    if json_path.endswith('input.json'):\n",
    "        continue\n",
    "\n",
    "    try:\n",
    "        with open(json_path, 'r') as fh:\n",
    "            data = json5.load(fh)\n",
    "\n",
    "        # Files without a caption contribute nothing and are not counted as invalid.\n",
    "        if 'caption' not in data:\n",
    "            continue\n",
    "\n",
    "        all_items.append({\n",
    "            \"image_id\": data['image'],\n",
    "            \"ground_truth\": data['caption'],\n",
    "            \"question\": \"Describe the image in detail\",\n",
    "            \"dataset\": \"RSBench\",\n",
    "            \"question_id\": valid_count,\n",
    "            \"type\": \"caption\",\n",
    "        })\n",
    "        valid_count += 1\n",
    "\n",
    "    except Exception as err:\n",
    "        # Report the unreadable / malformed file and move on to the next one.\n",
    "        print('skipping', json_path, 'error', err)\n",
    "        non_valid.append(json_path)\n",
    "        continue\n",
    "\n",
    "with open('RSBench_EVAL_Cap_v2.json', 'w') as out:\n",
    "    json.dump(all_items, out, indent=4)\n",
    "\n",
    "print('valid_count', valid_count, len(all_files), len(non_valid))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "base",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
